{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "name": "VMAS: Use vmas environment.ipynb",
   "provenance": [],
   "collapsed_sections": [
    "0NsC_EwfCF5I"
   ]
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  },
  "accelerator": "GPU",
  "gpuClass": "standard"
 },
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "## Initialization"
   ],
   "metadata": {
    "id": "0NsC_EwfCF5I"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "#@title\n",
    "# Clone into an explicit path, and only when it is missing, so that re-running\n",
    "# this cell after the `%cd` below does not create a nested second checkout.\n",
    "! [ -d /content/VectorizedMultiAgentSimulator ] || git clone https://github.com/proroklab/VectorizedMultiAgentSimulator.git /content/VectorizedMultiAgentSimulator\n",
    "%cd /content/VectorizedMultiAgentSimulator\n",
    "# `%pip` installs into the environment of the running kernel\n",
    "%pip install -e ."
   ],
   "metadata": {
    "id": "zjnXLxaOMLuv"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "#@title\n",
    "# Install a virtual framebuffer so the simulator can render without a physical screen.\n",
    "!sudo apt-get update\n",
    "# `-y` answers the confirmation prompt so the cell also works under \"Run all\"\n",
    "!sudo apt-get install -y python3-opengl xvfb\n",
    "%pip install pyvirtualdisplay\n",
    "import pyvirtualdisplay\n",
    "# Named `virtual_display` so it does not shadow IPython's built-in `display()` helper\n",
    "virtual_display = pyvirtualdisplay.Display(visible=False, size=(1400, 900))\n",
    "virtual_display.start()"
   ],
   "metadata": {
    "id": "1ZpWFjvHOpZJ"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Run\n"
   ],
   "metadata": {
    "id": "jAAA3DXGCLkF"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "from vmas.simulator.scenario import BaseScenario\n",
    "from typing import Union\n",
    "import time\n",
    "import torch\n",
    "from vmas import make_env\n",
    "from vmas.simulator.core import Agent\n",
    "\n",
    "def _get_deterministic_action(agent: Agent, continuous: bool, env):\n",
    "    \"\"\"Build the fixed action that `agent` applies in every vectorized environment.\n",
    "\n",
    "    Args:\n",
    "        agent (Agent): Agent whose `action.u_range_tensor` and `action_size` are read\n",
    "        continuous (bool): Whether to build a continuous or a discrete action\n",
    "        env: Environment providing `batch_dim` and `device`\n",
    "\n",
    "    Returns:\n",
    "        A freshly cloned tensor. In the continuous case it is the negated\n",
    "        `agent.action.u_range_tensor` expanded to `(env.batch_dim, agent.action_size)`;\n",
    "        in the discrete case it is the action index 1 as a long tensor expanded to\n",
    "        `(env.batch_dim, 1)` on `env.device`.\n",
    "\n",
    "    \"\"\"\n",
    "    if not continuous:\n",
    "        action_index = torch.tensor([1], device=env.device, dtype=torch.long)\n",
    "        per_env_index = action_index.unsqueeze(-1).expand(env.batch_dim, 1)\n",
    "        return per_env_index.clone()\n",
    "\n",
    "    per_env_range = agent.action.u_range_tensor.expand(\n",
    "        env.batch_dim, agent.action_size\n",
    "    )\n",
    "    return (-per_env_range).clone()\n",
    "\n",
    "def use_vmas_env(\n",
    "    render: bool,\n",
    "    num_envs: int,\n",
    "    n_steps: int,\n",
    "    device: str,\n",
    "    scenario: Union[str, BaseScenario],\n",
    "    continuous_actions: bool,\n",
    "    random_action: bool,\n",
    "    **kwargs\n",
    "):\n",
    "    \"\"\"Run a vmas environment for `n_steps` steps and optionally save a gif of it.\n",
    "\n",
    "    This is a simplification of the function in `vmas.examples.use_vmas_env.py`.\n",
    "\n",
    "    Args:\n",
    "        render (bool): Whether to render every step. When True, the collected frames\n",
    "            are written to `<scenario name>.gif` in the current working directory\n",
    "        num_envs (int): Number of vectorized environments\n",
    "        n_steps (int): Number of environment steps to simulate\n",
    "        device (str): Torch device to use\n",
    "        scenario (Union[str, BaseScenario]): Name of the scenario, or a scenario\n",
    "            instance (its class name is then used in the log message and gif name)\n",
    "        continuous_actions (bool): Whether the agents have continuous or discrete actions\n",
    "        random_action (bool): If True, use `env.get_random_action` for every agent;\n",
    "            otherwise all agents repeat the fixed action built by\n",
    "            `_get_deterministic_action`\n",
    "        **kwargs: Environment specific variables, forwarded to `make_env`\n",
    "\n",
    "    \"\"\"\n",
    "\n",
    "    scenario_name = (\n",
    "        scenario if isinstance(scenario, str) else scenario.__class__.__name__\n",
    "    )\n",
    "\n",
    "    env = make_env(\n",
    "        scenario=scenario,\n",
    "        num_envs=num_envs,\n",
    "        device=device,\n",
    "        continuous_actions=continuous_actions,\n",
    "        seed=0,\n",
    "        # Environment specific variables\n",
    "        **kwargs\n",
    "    )\n",
    "\n",
    "    frame_list = []  # For creating a gif\n",
    "    init_time = time.time()\n",
    "\n",
    "    for step in range(1, n_steps + 1):\n",
    "        print(f\"Step {step}\")\n",
    "\n",
    "        # One action tensor per agent, in the order of `env.agents`\n",
    "        actions = [\n",
    "            env.get_random_action(agent)\n",
    "            if random_action\n",
    "            else _get_deterministic_action(agent, continuous_actions, env)\n",
    "            for agent in env.agents\n",
    "        ]\n",
    "\n",
    "        # The step outputs (obs, rews, dones, info) are not used by this example\n",
    "        env.step(actions)\n",
    "\n",
    "        if render:\n",
    "            frame = env.render(\n",
    "                mode=\"rgb_array\",\n",
    "                agent_index_focus=None,  # Can give the camera an agent index to focus on\n",
    "            )\n",
    "            frame_list.append(frame)\n",
    "\n",
    "    total_time = time.time() - init_time\n",
    "    print(\n",
    "        f\"It took: {total_time}s for {n_steps} steps of {num_envs} parallel environments on device {device} \"\n",
    "        f\"for {scenario_name} scenario.\"\n",
    "    )\n",
    "\n",
    "    if render:\n",
    "        if not frame_list:\n",
    "            # No step was run (n_steps <= 0), so there is nothing to encode\n",
    "            print(\"No frames were rendered; skipping gif creation.\")\n",
    "            return\n",
    "\n",
    "        try:\n",
    "            from moviepy.editor import ImageSequenceClip\n",
    "        except ImportError:\n",
    "            # moviepy >= 2.0 removed the `moviepy.editor` module\n",
    "            from moviepy import ImageSequenceClip\n",
    "\n",
    "        fps = 30\n",
    "        clip = ImageSequenceClip(frame_list, fps=fps)\n",
    "        clip.write_gif(f'{scenario_name}.gif', fps=fps)"
   ],
   "metadata": {
    "id": "2Ol4AFeRQ3Ma"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "scenario_name = \"waterfall\"\n",
    "use_vmas_env(\n",
    "    scenario=scenario_name,\n",
    "    render=True,\n",
    "    num_envs=32,\n",
    "    n_steps=100,\n",
    "    # Fall back to the CPU when the runtime has no GPU attached\n",
    "    device=\"cuda\" if torch.cuda.is_available() else \"cpu\",\n",
    "    continuous_actions=False,\n",
    "    random_action=False,\n",
    "    # Environment specific variables\n",
    "    n_agents=4,\n",
    ")"
   ],
   "metadata": {
    "id": "3cskWki-O8Ul"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# Show the `<scenario_name>.gif` file that `use_vmas_env` writes when `render=True`\n",
    "from IPython.display import Image\n",
    "Image(f'{scenario_name}.gif')"
   ],
   "metadata": {
    "id": "UPRa91hMPU1n"
   },
   "execution_count": null,
   "outputs": []
  }
 ]
}
